#!/usr/bin/env python3

"""
Compare the key hierarchies in two or more YAML files and report differences.

This is handy for comparing lava framework config files to see if configs for
different environments have drifted apart.

"""

from __future__ import annotations

import argparse
import os
import sys
from collections.abc import Iterable
from fnmatch import fnmatch
from itertools import combinations
from typing import Any

import yaml

__author__ = 'Murray Andrews'

# Program name used in the usage text and in error messages: the basename of
# the invoked script (sys.argv[0]) with its extension stripped.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

# ..............................................................................
# region colour

# ------------------------------------------------------------------------------
# Clunky support for colour output if colorama is not installed.


try:
    # Prefer colorama when it is installed; it supplies the Fore and Style
    # names referenced below.
    # noinspection PyUnresolvedReferences
    import colorama

    # noinspection PyUnresolvedReferences
    from colorama import Fore, Style

    colorama.init()

except ImportError:
    # colorama is not available: define minimal stand-in classes named Fore and
    # Style so the constants below can be built the same way in either case.

    class Fore:
        """Basic alternative to colorama colours using ANSI sequences."""

        RESET = '\033[0m'
        BLACK = '\033[30m'
        RED = '\033[31m'
        GREEN = '\033[32m'
        YELLOW = '\033[33m'
        BLUE = '\033[34m'
        MAGENTA = '\033[35m'
        CYAN = '\033[36m'

    class Style:
        """Basic alternative to colorama styles using ANSI sequences."""

        RESET_ALL = '\033[0m'
        BRIGHT = '\033[1m'
        DIM = '\033[2m'
        NORMAL = '\033[22m'


# C0 resets both colour and style; it is appended after every coloured fragment.
C0 = Fore.RESET + Style.RESET_ALL
# C1 / C2 colour the output for the first / second file of a compared pair.
C1 = Fore.GREEN
C2 = Fore.BLUE

# Highlight versions
# (the same colours combined with the bright style; used for file name headings)
CC1 = C1 + Style.BRIGHT
CC2 = C2 + Style.BRIGHT


# endregion colour
# ..............................................................................


# ------------------------------------------------------------------------------
def flatten_dict_keys(d: dict[str, Any], parent: str | None = None) -> set[str]:
    """
    Convert hierarchical dictionary keys into parent.child.child... format.

    Every key is reported, including keys whose values are nested dictionaries
    (so ``{'a': {'b': 1}}`` yields both ``a`` and ``a.b``). Only values that
    are dictionaries are descended into; lists and scalars are leaves.

    Keys are always converted to strings. YAML permits non-string keys (e.g.
    integers) and the callers sort and glob-match the result, both of which
    require a homogeneous set of strings.

    :param d:           A dictionary.
    :param parent:      The flattened key name of the parent dictionary, or
                        None when ``d`` is the top level dictionary.

    :return:            A set of hierarchical key names in the form x.y.z.

    """

    s = set()

    for k, v in d.items():
        # Test against None rather than truthiness so that a falsy parent key
        # (such as an empty string) still prefixes its children.
        current = str(k) if parent is None else f'{parent}.{k}'
        s.add(current)
        if isinstance(v, dict):
            s |= flatten_dict_keys(v, parent=current)

    return s


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    The returned namespace has two attributes:

    - ``ignore``: a list of glob patterns collected from repeated
      ``-i`` / ``--ignore`` options, or None if the option was not used.
    - ``yaml_file``: a (possibly empty) list of YAML file names.

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(
        prog=PROG,
        description='Compare key hierarchies in two or more YAML files and report differences.',
        epilog='Exits with non-zero status if any key differences are detected.',
    )

    argp.add_argument(
        '-i',
        '--ignore',
        action='append',
        # Adjacent literals are concatenated verbatim, so the separating space
        # must be part of one of them.
        help=(
            'Ignore keys with a flattened (dot separated hierarchy) '
            'form that match the given glob pattern.'
        ),
    )

    # nargs='*' accepts any number of files, including none; the caller decides
    # what to do when fewer than two are supplied.
    argp.add_argument(
        'yaml_file',
        action='store',
        nargs='*',
        help='YAML files. There needs to be at least 2 files for this to be useful.',
    )

    return argp.parse_args()


# ------------------------------------------------------------------------------
def read_yaml(filename: str) -> dict[str, Any]:
    """
    Read a YAML file containing a dictionary.

    An empty file (or one holding only comments / a null document) is treated
    as an empty dictionary.

    :param filename:    YAML file name.

    :return:            The dictionary decoded from the YAML.

    :raise Exception:   If the file cannot be read or parsed, or if its top
                        level is not a mapping. The message is prefixed with
                        the file name and the underlying error, when there is
                        one, is chained as the cause.
    """

    try:
        # Open in binary mode so that PyYAML determines the encoding itself
        # (UTF-8 or BOM-marked UTF-16) instead of depending on the locale.
        with open(filename, 'rb') as fp:
            data = yaml.safe_load(fp)
    except Exception as e:
        raise Exception(f'{filename}: {e}') from e

    if data is None:
        return {}

    if not isinstance(data, dict):
        # Fail here with a clear message rather than later with an obscure
        # AttributeError when the caller tries to walk the keys.
        raise Exception(
            f'{filename}: top level of YAML document must be a mapping,'
            f' not {type(data).__name__}'
        )

    return data


# ------------------------------------------------------------------------------
def set_exclude_glob_matches(s: set[str], patterns: Iterable[str] | None = None) -> set[str]:
    """
    Create a new set from a set of strings, excluding any that match any of the glob patterns.

    Matching is done with ``fnmatch.fnmatch``. The input set is not modified.

    :param s:           A set of strings.
    :param patterns:    An iterable of exclusion patterns. None or an empty
                        iterable means nothing is excluded. One-shot iterables
                        (e.g. generators) are supported.
    :return:            A new set without any items matching the exclusion
                        patterns.
    """

    # Materialise the patterns once: they are scanned again for every item, so
    # a one-shot iterator would otherwise be exhausted after the first item.
    globs = tuple(patterns) if patterns else ()

    return {item for item in s if not any(fnmatch(item, pat) for pat in globs)}


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Compare the flattened key sets of every pair of YAML files given on the command line.

    For each pair that differs (after removing ignored keys), the keys unique
    to each file are printed under that file's name. Nothing is compared when
    fewer than 2 files are supplied.

    :return:        1 if significant differences, 0 otherwise
    """

    args = process_cli_args()
    if len(args.yaml_file) < 2:
        return 0

    # Load every file up front as (file name, flattened key set).
    keys_by_file = [(name, flatten_dict_keys(read_yaml(name))) for name in args.yaml_file]
    diff_count = 0

    # Do pairwise comparison of keys in the files
    for (name_1, keys_1), (name_2, keys_2) in combinations(keys_by_file, 2):
        only_in_1 = set_exclude_glob_matches(keys_1 - keys_2, args.ignore)
        only_in_2 = set_exclude_glob_matches(keys_2 - keys_1, args.ignore)

        if not (only_in_1 or only_in_2):
            continue

        # We have some differences to report. The first report is preceded by a
        # blank line, subsequent ones by a separator rule.
        if not diff_count:
            print()
        else:
            print(f'{Style.DIM}--------------------{C0}')

        diff_count += len(only_in_1) + len(only_in_2)

        # Both file names are always shown, even if only one side has extra keys.
        print(f'{CC1}{name_1}:{C0}')
        if only_in_1:
            listing_1 = ', '.join(sorted(only_in_1))
            print(f'    {C1}{listing_1}{C0}')

        print(f'{CC2}{name_2}:{C0}')
        if only_in_2:
            listing_2 = ', '.join(sorted(only_in_2))
            print(f'    {C2}{listing_2}{C0}')

    if diff_count:
        print()

    return 1 if diff_count else 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # sys.exit() is used rather than the exit() builtin because the latter is
    # injected by the site module and is not guaranteed to exist (e.g. when
    # running with ``python -S`` or in a frozen application).
    #
    # Uncomment for debugging
    # sys.exit(main())  # noqa: ERA001
    try:
        sys.exit(main())
    except Exception as ex:
        # Top level boundary: report any failure as a one line message.
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        sys.exit(2)
